/*
 * Copyright 2017 LinkedIn, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package com.linkedin.parseq.internal;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.function.Consumer;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.event.Level;

/**
 * This class allows active monitoring of ParSeq threads in order to detect situations where,
 * due to a programming error, a task is running for a very long time e.g. due to deadlock, blocking etc.
 * <p>
 * When monitoring is enabled each ParSeq thread has a {@code ThreadLocal<ExecutionMonitorState>} which
 * is updated whenever a {@code Runnable} is started and completed. This class starts a daemon monitoring thread
 * which checks status of all actively monitored threads. When a thread is executing a {@code Runnable} for
 * a long time (as specified by the {@code durationThresholdNano} parameter) a log is emitted at the specified
 * log level. The log contains a full thread dump as well as the state of all actively monitored threads before
 * and after the thread dump.
 * <p>
 * The intention of this class is to catch and help solving programming errors, not problems related to the java
 * platform or underlying resources e.g. we try to avoid reporting long STW GCs. In order to achieve it the
 * monitoring thread tries to detect significant hiccups (situations when it was not woken up on time) and uses
 * them as a signal that a stall was caused by an external factor such as a long STW GC.
 * <p>
 * Monitoring mechanism implemented by this class tries to minimize overhead of ParSeq threads.
 * <p>
 * Threading: {@code _monitors}, {@code _stalls} and the timing fields below are touched only by the monitoring
 * thread. Monitored threads communicate with it exclusively through the volatile fields of
 * {@link ExecutionMonitorState} and the {@code _addedMonitors} queue.
 *
 * @author Jaroslaw Odzga (jodzga@linkedin.com)
 *
 */
public class ExecutionMonitor {

  private static final Logger LOG = LoggerFactory.getLogger(ExecutionMonitor.class);
  private static final ThreadDumper THREAD_DUMPER = new ThreadDumper();

  /*
   * DecimalFormat is not synchronized, so it must not be shared between monitoring threads of
   * different ExecutionMonitor instances. This instance is used only by this monitor's own thread.
   */
  private final DecimalFormat _decimalFormat = new DecimalFormat("#.###");

  private final ThreadLocal<ExecutionMonitorState> LOCAL_MONITOR =
      ThreadLocal.withInitial(() -> new ExecutionMonitorState(Thread.currentThread().getId()));

  /** States registered by monitored threads, waiting to be picked up by the monitoring thread. */
  private final ConcurrentLinkedQueue<ExecutionMonitorState> _addedMonitors =
      new ConcurrentLinkedQueue<ExecutionMonitorState>();

  private final int _maxMonitors;
  private final long _durationThresholdNano;
  private final long _checkIntervalNano;
  private final long _idleDurationNano;
  private final long _loggingIntervalNano;
  private final long _minStallNano;
  private final int _stallsHistorySize;
  private final Consumer<String> _logger;
  private final Clock _clock;
  private volatile boolean _stopped = false;

  private final Set<ExecutionMonitorState> _monitors = new HashSet<>();
  /** Detected stalls: key is the time of the monitoring step preceding the stall, value is the stall length. */
  private final TreeMap<Long, Long> _stalls = new TreeMap<>();
  private final Thread _monitoringThread;

  private long _lastMonitoringStep;
  private long _nextAllowedLogging;
  private long _shortestObservedDelta = Long.MAX_VALUE;

  /**
   * Creates an instance of ExecutionMonitor and starts monitoring daemon thread.
   *
   * @param maxMonitors maximum number of monitored threads. It should be greater than expected number of threads
   * in ParSeq thread pool. When number of active threads exceeds this number then warning is logged and some threads
   * will be ignored.
   * @param durationThresholdNano specifies the duration of execution that is considered as exceedingly long. Thread
   * executing {@code Runnable} longer than this value will trigger log event containing thread dump and state of
   * all actively monitored threads.
   * @param checkIntervalNano interval at which monitoring thread checks for long running tasks.
   * @param idleDurationNano specifies amount of time after which thread is removed from a list of actively monitored
   * threads if it has not been running any tasks. This allows detection of dead and inactive threads.
   * @param loggingIntervalNano specifies minimum amount of time between two log events generated by this class. This
   * is to avoid excessive logging because thread dumps can be large.
   * @param minStallNano specifies minimum amount of time which is considered a significant stall. In order to avoid
   * triggering thread dumps on events such as long STW GCs this class tries to identify stalls caused by external
   * factors by measuring the difference between time at which monitoring thread was woken up and scheduled wake up time.
   * The difference is considered to be a stall if it is larger than this parameter.
   * @param stallsHistorySize size of the maximum stalls history to be kept in memory
   * @param level level at which log events containing state of monitored threads and thread dump is generated.
   * @param clock clock implementation to be used.
   */
  public ExecutionMonitor(int maxMonitors, long durationThresholdNano, long checkIntervalNano, long idleDurationNano,
      long loggingIntervalNano, long minStallNano, int stallsHistorySize, Level level, Clock clock) {
    _maxMonitors = maxMonitors;
    _durationThresholdNano = durationThresholdNano;
    _checkIntervalNano = checkIntervalNano;
    _idleDurationNano = idleDurationNano;
    _loggingIntervalNano = loggingIntervalNano;
    _minStallNano = minStallNano;
    _stallsHistorySize = stallsHistorySize;
    _clock = clock;

    switch(level) {
      case INFO:
        _logger = LOG::info;
        break;
      case DEBUG:
        _logger = LOG::debug;
        break;
      case ERROR:
        _logger = LOG::error;
        break;
      case TRACE:
        _logger = LOG::trace;
        break;
      case WARN:
      default:
        //any level without a dedicated branch is reported as a warning
        _logger = LOG::warn;
        break;
    }
    _monitoringThread = new Thread(this::monitor);
    _monitoringThread.setDaemon(true);
    _monitoringThread.setName("ParSeqExecutionMonitor");
    _monitoringThread.start();
  }

  /**
   * Returns the monitor state bound to the calling thread, creating it on first access.
   */
  ExecutionMonitorState getLocalMonitorState() {
    return LOCAL_MONITOR.get();
  }

  /**
   * Requests the monitoring thread to stop. The flag is checked once per monitoring loop iteration,
   * so the thread exits after its current sleep and monitoring step complete.
   */
  void shutdown() {
    _stopped = true;
  }

  /**
   * Main loop of monitoring thread.
   */
  private void monitor() {
    _lastMonitoringStep = _clock.nanoTime();
    _nextAllowedLogging = _lastMonitoringStep;
    while(!_stopped) {
      try {
        _clock.sleepNano(_checkIntervalNano);
      } catch (InterruptedException e) {
        //interruption terminates the monitoring thread
        break;
      }
      monitorStep();
    }
  }

  /**
   * Single iteration of the monitoring loop: records a stall if this wake up was late, registers newly
   * activated threads, looks for threads that have been active longer than the threshold, evicts threads
   * that have been idle for too long and emits a log event if needed.
   */
  private void monitorStep() {
    long currentTime = _clock.nanoTime();

    checkForStall(currentTime);

    drainAddedMonitorsQueue();

    List<ExecutionMonitorState> toRemove = new ArrayList<>();

    long oldestTimestamp = currentTime;
    boolean thereAreLongRunningThreads = false;

    for (ExecutionMonitorState m : _monitors) {
      //read the volatile timestamp once so that all decisions about this thread use the same value
      final long lastUpdate = m._lastUpdate;
      if (m._isActive) {
        //we can't write if (lastUpdate < oldestTimestamp) because of an overflow of signed long
        if (lastUpdate - oldestTimestamp < 0) {
          oldestTimestamp = lastUpdate;
        }

        /*
         * stallTime represents time that we assume thread was not responsible for.
         * This is a heuristic that allows us avoid reporting situations such as a long STW GC.
         */
        long stallTime = getStallsSince(lastUpdate);
        long activeTime = (currentTime - lastUpdate) - stallTime;
        if (activeTime > _durationThresholdNano) {
          thereAreLongRunningThreads = true;
        }
      } else {
        if (currentTime - lastUpdate > _idleDurationNano) {
          toRemove.add(m);
        }
      }
    }

    for (ExecutionMonitorState m : toRemove) {
      _monitors.remove(m);
      /*
       * Reset the flag so that activate() registers this state again the next time the thread runs a task.
       * Without it a thread evicted for being idle would never be monitored again.
       */
      m._isMonitored = false;
    }

    //generate log events if there was a long running thread and we don't exceed allowed logging frequency
    if (thereAreLongRunningThreads && currentTime - _nextAllowedLogging >= 0) {
      _nextAllowedLogging = currentTime + _loggingIntervalNano;
      logMonitoredThreads(_monitors);
    }

    //remove stalls that will not be used anymore
    trimStalls(oldestTimestamp);

    _lastMonitoringStep = _clock.nanoTime();
  }

  /**
   * Check how much we missed scheduled wake up and if it is larger than _minStallNano
   * then consider it a stall and remember it.
   * <p>
   * The shortest delta between two monitoring steps observed so far is used as the baseline for
   * an on-time wake up; anything above that baseline is treated as lateness.
   */
  private void checkForStall(long currentTime) {
    long delta = currentTime - _lastMonitoringStep;
    if (delta < _shortestObservedDelta) {
      _shortestObservedDelta = delta;
    }
    long stall = Math.max(0, delta - _shortestObservedDelta);
    if (stall > _minStallNano) {
      _stalls.put(_lastMonitoringStep, stall);
      //keep the history bounded by dropping the entry with the smallest timestamp
      if (_stalls.size() > _stallsHistorySize) {
        _stalls.pollFirstEntry();
      }
    }
  }

  /**
   * Moves states registered by monitored threads into the set of monitored threads, up to
   * {@code _maxMonitors}. States that do not fit are dropped with a warning.
   */
  private void drainAddedMonitorsQueue() {
    ExecutionMonitorState monitor;
    while ((monitor = _addedMonitors.poll()) != null) {
      if (_monitors.size() < _maxMonitors) {
        _monitors.add(monitor);
      } else {
        LOG.warn("Exceeded number of maximum monitored threads, thread with Id=" + monitor._threadId
            + " will not be monitored");
      }
    }
  }

  /**
   * Builds and emits a single log event containing the state of monitored threads before the thread dump,
   * the thread dump itself and the state of monitored threads after the thread dump.
   */
  private void logMonitoredThreads(Set<ExecutionMonitorState> monitoredThreads) {
    StringBuilder sb = new StringBuilder();

    sb.append("Found ParSeq threads running longer than ")
      .append(_decimalFormat.format(((double) _durationThresholdNano) / 1000000))
      .append("ms.\n\nMonitored ParSeq threads before thread dump: \n");

    logMonitoredThreads(monitoredThreads, _clock.nanoTime(), sb);

    sb.append("\nThread dump:\n\n");

    THREAD_DUMPER.threadDump(sb);

    sb.append("Monitored ParSeq threads after thread dump: \n");

    logMonitoredThreads(monitoredThreads, _clock.nanoTime(), sb);

    _logger.accept(sb.toString());

  }

  /**
   * Appends one line per monitored thread describing for how long it has been busy or idle.
   * Lines of threads that are active for longer than the threshold are prefixed with "(!)".
   */
  private void logMonitoredThreads(Set<ExecutionMonitorState> monitoredThreads, long currentTime, StringBuilder sb) {
    for (ExecutionMonitorState m : monitoredThreads) {
      //snapshot volatile state so that the marker and the description on one line agree with each other
      final boolean active = m._isActive;
      long runTime = Math.max(0,  currentTime - m._lastUpdate);
      if (runTime > _durationThresholdNano && active) {
        sb.append("(!) ");
      } else {
        sb.append("    ");
      }
      sb.append("Thread Id=")
      .append(m._threadId)
      .append(active ? " busy for " : " idle for ")
      .append(_decimalFormat.format(((double)runTime) / 1000000))
      .append("ms\n");
    }
  }

  /**
   * Removes stalls recorded before the given timestamp: they precede the last update of every
   * active thread so they will not be used by {@link #getStallsSince(long)} anymore.
   */
  private void trimStalls(long oldestTimestamp) {
    while (!_stalls.isEmpty() && _stalls.firstKey() < oldestTimestamp) {
      _stalls.pollFirstEntry();
    }
  }

  /**
   * Returns sum of all remembered stalls recorded at or after the given timestamp.
   */
  long getStallsSince(long lastUpdate) {
    long stall = 0;
    for (long duration : _stalls.tailMap(lastUpdate, true).values()) {
      stall += duration;
    }
    return stall;
  }

  /**
   * Per-thread state shared between a monitored thread (which calls {@link #activate()} and
   * {@link #deactivate()}) and the monitoring thread (which reads the volatile fields).
   * Equality and hash code are based solely on the thread id.
   */
  class ExecutionMonitorState {

    private final long _threadId;
    private volatile long _lastUpdate = 0;
    private volatile boolean _isActive = false;
    //true while this state is registered (or queued for registration) with the monitoring thread
    private volatile boolean _isMonitored = false;

    public ExecutionMonitorState(long threadId) {
      _threadId = threadId;
    }

    /**
     * Marks the thread as busy starting from now and registers it with the monitoring thread
     * if it is not currently registered.
     */
    public void activate() {
      _lastUpdate = _clock.nanoTime();
      _isActive = true;
      if (!_isMonitored) {
        _isMonitored = true;
        _addedMonitors.add(this);
      }
    }

    /**
     * Marks the thread as idle starting from now.
     */
    public void deactivate() {
      _lastUpdate = _clock.nanoTime();
      _isActive = false;
    }

    @Override
    public int hashCode() {
      return 31 + Long.hashCode(_threadId);
    }

    @Override
    public boolean equals(Object obj) {
      if (this == obj) {
        return true;
      }
      if (obj == null || getClass() != obj.getClass()) {
        return false;
      }
      return _threadId == ((ExecutionMonitorState) obj)._threadId;
    }
  }
}
